#!/usr/bin/env python3

"""A test case generator for register stackification.

This script exhaustively generates small linear SSA programs, then filters them
based on heuristics designed to keep interesting multivalue test cases and
prints them as LLVM IR functions in a FileCheck test file.

The output of this script is meant to be used in conjunction with
update_llc_test_checks.py.

  ```
  ./multivalue-stackify.py > multivalue-stackify.ll
  ../../../utils/update_llc_test_checks.py multivalue-stackify.ll
  ```

Programs are represented internally as lists of operations. Each operation is a
pair of tuples: the first specifies the operation's uses and the second its
defs. Values are numbered consecutively in definition order, so each use is the
index of a value defined by an earlier operation.

TODO: Before embarking on a rewrite of the register stackifier, an abstract
interpreter should be written to automatically check that the test assertions
generated by update_llc_test_checks.py have the same semantics as the functions
generated by this script. Once that is done, exhaustive testing can be done by
making `is_interesting` return True.
"""


from collections import Counter, deque
from itertools import product


# Maximum number of ops in a generated program; generate_programs() stops
# extending programs once they reach this length.
MAX_PROGRAM_OPS = 4
# Maximum total number of values defined across a whole program (and hence by
# any single op). Also bounds the result counts declared by print_header().
MAX_PROGRAM_DEFS = 3
# Maximum number of operands (uses) a single op may take.
MAX_OP_USES = 2


def get_num_defs(program):
  """Return the total number of values defined by all ops in `program`."""
  return sum(len(defs) for _, defs in program)


def possible_ops(program):
  """Yield every (uses, defs) op that may be appended to `program`.

  New defs are numbered consecutively after the values `program` already
  defines, and each use refers to one of those existing values. The op with
  neither uses nor defs is skipped. Ops are produced ordered by def count, then
  use count, then lexicographically by uses.
  """
  existing = get_num_defs(program)
  for def_count in range(MAX_PROGRAM_DEFS - existing + 1):
    new_defs = tuple(range(existing, existing + def_count))
    for use_count in range(MAX_OP_USES + 1):
      if def_count == 0 and use_count == 0:
        continue
      for uses in product(range(existing), repeat=use_count):
        yield uses, new_defs


def generate_programs():
  """Yield (program_id, program) for every program, in breadth-first order.

  Programs are built by appending each op from possible_ops() to a shorter
  program. Ids start at 1 and count every generated program, including ones
  that are later filtered out. Generation ends once the first program of
  MAX_PROGRAM_OPS ops is dequeued, at which point every program of that length
  has already been yielded.

  Each yielded list is also kept in the internal queue for later extension,
  so callers should not mutate it.
  """
  frontier = deque()
  frontier.append([])
  counter = 0
  while True:
    prefix = frontier.popleft()
    if len(prefix) == MAX_PROGRAM_OPS:
      return
    for op in possible_ops(prefix):
      counter += 1
      extended = [*prefix, op]
      frontier.append(extended)
      yield counter, extended


def get_num_terminal_ops(program):
  """Count the ops in `program` that define no values."""
  return sum(1 for _, defs in program if not defs)


def get_max_uses(program):
  """Return the largest number of times any single value is used in `program`.

  Returns 0 when no value is used at all (e.g. an empty program). Use counts
  are tallied per value index, so this does not depend on MAX_PROGRAM_DEFS and
  works for programs of any size.
  """
  use_counts = Counter(u for uses, _ in program for u in uses)
  # `default=0` matches "no uses" and avoids max() on an empty sequence.
  return max(use_counts.values(), default=0)


def has_unused_op(program):
  """Return True if some value-producing op has none of its defs used.

  The program is scanned backwards, so every later use has been recorded by
  the time an op is examined. Ops without defs are never considered unused. A
  multi-def op counts as used if any one of its defs is used. Used values are
  tracked in a set, so programs of any size are handled.
  """
  used = set()
  for uses, defs in reversed(program):
    if defs and used.isdisjoint(defs):
      return True
    used.update(uses)
  return False


def has_multivalue_use(program):
  """Return True if any op uses a value defined by an op with 2+ defs.

  Values from multi-def ops are tracked in a set, so programs of any size are
  handled.
  """
  multi_defs = set()
  for uses, defs in program:
    if not multi_defs.isdisjoint(uses):
      return True
    if len(defs) >= 2:
      multi_defs.update(defs)
  return False


def has_mvp_use(program):
  """Return True if the program contains a "boring" use of single-value defs.

  A value is treated as MVP-style when its defining op defines at most one
  value. (The name presumably refers to the WebAssembly MVP, which predates
  multivalue; confirm.)

  The program is flagged when either:
    * an op has at least one use and every one of its uses is an MVP value, or
    * an op defining at most one value uses any MVP value.

  An op's own defs are recorded only after its uses are checked. Values are
  tracked in a set, so programs of any size are handled.
  """
  mvp_values = set()
  for uses, defs in program:
    if uses and mvp_values.issuperset(uses):
      return True
    if len(defs) <= 1:
      if not mvp_values.isdisjoint(uses):
        return True
      mvp_values.update(defs)
  return False


def is_interesting(program):
  """Return whether `program` should be emitted as a test function.

  Single-op programs are kept only when that op defines several values. Longer
  programs must survive every rejection heuristic below, and must also contain
  at least one use of a value defined by a multi-def op.
  """
  if len(program) == 1:
    only_defs = program[0][1]
    return len(only_defs) > 1

  # Reject programs whose last two instructions use exactly the same operands.
  # Only the uses are compared, not the defs.
  if len(program) >= 2:
    last_uses = program[-1][0]
    penultimate_uses = program[-2][0]
    if last_uses == penultimate_uses:
      return False

  # Heuristic rejections, evaluated in this order:
  #  - more than two ops that produce no values;
  #  - a value used three or more times (the third use of a value is no more
  #    interesting than the second);
  #  - a value-producing op none of whose results are used;
  #  - boring MVP uses of MVP defs.
  if (get_num_terminal_ops(program) > 2
      or get_max_uses(program) >= 3
      or has_unused_op(program)
      or has_mvp_use(program)):
    return False

  # Whatever survives is interesting iff it exercises multivalue usage.
  return has_multivalue_use(program)


def make_llvm_type(num_defs):
  """Return the LLVM return type for an op defining `num_defs` values.

  Zero defs maps to `void`. Otherwise the result is a literal struct of
  `num_defs` i32 fields, e.g. `{i32, i32}`.
  """
  if not num_defs:
    return 'void'
  fields = ', '.join('i32' for _ in range(num_defs))
  return f'{{{fields}}}'


def make_llvm_op_name(num_uses, num_defs):
  """Return the callee name for an op with the given use and def counts.

  For example, an op with 2 uses and 1 def is named `op_2_to_1`.
  """
  name = f'op_{num_uses}_to_{num_defs}'
  return name


def make_llvm_args(first_use, num_uses):
  """Return an LLVM argument list of `num_uses` consecutive i32 temporaries.

  The temporaries are numbered starting at `%t<first_use>`. An empty string is
  returned when `num_uses` is 0.
  """
  return ', '.join(f'i32 %t{first_use + offset}' for offset in range(num_uses))


def print_llvm_program(program, name):
  """Print `program` as the LLVM IR function `define void @name()`.

  Every op becomes a call to its `op_<uses>_to_<defs>` declaration. Each use
  is first materialized with an `extractvalue` from the aggregate returned by
  its defining call. As a result, a call's operands are always fresh,
  consecutively numbered `%t` temporaries.
  """
  next_tmp = 0
  # Indexed by def number: (aggregate type, %t number of the call, field index).
  def_sources = []
  print(f'define void @{name}() {{')
  for uses, defs in program:
    args_start = next_tmp
    # Extract each operand into its own temporary.
    for use in uses:
      agg_type, agg_var, field = def_sources[use]
      print(f'  %t{next_tmp} = extractvalue {agg_type} %t{agg_var}, {field}')
      next_tmp += 1
    # Only calls that define values get a named result.
    assignment = ''
    if defs:
      result_var = next_tmp
      next_tmp += 1
      assignment = f'%t{result_var} = '
    ret_type = make_llvm_type(len(defs))
    op_name = make_llvm_op_name(len(uses), len(defs))
    args = make_llvm_args(args_start, len(uses))
    print(f'  {assignment}call {ret_type} @{op_name}({args})')
    # Record where each of this op's defs can later be extracted from. The
    # generator is empty (and never reads result_var) when there are no defs.
    def_sources.extend((ret_type, result_var, i) for i in range(len(defs)))
  print('  ret void')
  print('}')


def print_header():
  """Print the test-file preamble and a declaration for every op signature.

  The preamble holds the `; RUN:` line and the wasm32 target triple. An op is
  declared for each combination of 0..MAX_OP_USES operands and
  0..MAX_PROGRAM_DEFS results. The combination with no operands and no
  results is skipped, since possible_ops() never produces it.
  """
  print('; NOTE: Test functions have been generated by multivalue-stackify.py.')
  print()
  print('; RUN: llc < %s -verify-machineinstrs -mattr=+multivalue',
        '| FileCheck %s')
  print()
  print('; Test that the multivalue stackification works')
  print()
  print('target triple = "wasm32-unknown-unknown"')
  print()
  signatures = product(range(MAX_OP_USES + 1), range(MAX_PROGRAM_DEFS + 1))
  for n_uses, n_defs in signatures:
    if n_uses or n_defs:
      ret = make_llvm_type(n_defs)
      callee = make_llvm_op_name(n_uses, n_defs)
      params = make_llvm_args(0, n_uses)
      print(f'declare {ret} @{callee}({params})')
  print()


if __name__ == '__main__':
  # Emit the preamble once, then one function per interesting program. Each
  # function is named after its generation id, which counts every generated
  # program, not only the emitted ones.
  print_header()
  selected = ((pid, prog) for pid, prog in generate_programs()
              if is_interesting(prog))
  for pid, prog in selected:
    print_llvm_program(prog, 'f' + str(pid))
    print()
